{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2022, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}

# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import os
import sys
import unittest

import numpy as np

# Add the generated package to the path
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

import sym
from symforce import logger


class GeoPackageTest(unittest.TestCase):
    """
    Tests for Python runtime geometry types. Mostly checking basic generation logic
    since the math is tested comprehensively in symbolic form.

    The same four tests (storage ops, group ops, Lie group ops, custom methods) are
    emitted once for every generated geometry type.
    """

    def setUp(self):
        # type: () -> None
        # Seed numpy's global RNG so the random perturbations below are reproducible
        np.random.seed(42)
        # Store verbosity flag so tests can use it
        self.verbose = ("-v" in sys.argv) or ("--verbose" in sys.argv)

    {% for cls in all_types %}
    def test_storage_ops_{{ cls.__name__ }}(self):
        # type: () -> None
        """
        Tests:
            sym.{{ cls.__name__ }} StorageOps

        Checks that the storage of a default-constructed value has storage_dim()
        entries and that the value round-trips through to_storage / from_storage.
        """

        geo_class = sym.{{ cls.__name__ }}
        logger.debug("*** Testing StorageOps: %s ***", geo_class.__name__)

        value = geo_class()
        self.assertEqual(len(value.data), geo_class.storage_dim())

        vec = value.to_storage()
        self.assertGreater(len(vec), 0)
        self.assertEqual(len(vec), geo_class.storage_dim())
        for i, element in enumerate(vec):
            self.assertEqual(element, value.data[i])

        value2 = geo_class.from_storage(vec)
        self.assertEqual(value.data, value2.data)

        # Building from a modified storage vector must give different data than the original
        vec[0] = 2.1
        value3 = geo_class.from_storage(vec)
        self.assertNotEqual(value.data, value3.data)

    def test_group_ops_{{ cls.__name__ }}(self):
        # type: () -> None
        """
        Tests:
            sym.{{ cls.__name__ }} GroupOps

        Checks that the default-constructed value equals the identity and that
        compose, inverse and between applied to the identity return the identity.
        """
        geo_class = sym.{{ cls.__name__ }}
        group_ops = sym.ops.{{ cls.__name__.lower() }}.GroupOps
        logger.debug("*** Testing GroupOps: %s ***", geo_class.__name__)

        identity = geo_class()

        # TODO(Nathan): Consider reorganizing how the generated python geo package is structured so that
        # each class doesn't have to use helper functions to call the underlying group_ops functions
        # Example using the underlying group_ops implementation:
        self.assertEqual(identity, group_ops.identity())

        # Example using the helper functions:
        self.assertEqual(identity, geo_class.identity())
        self.assertEqual(identity, identity.compose(identity))
        self.assertEqual(identity, identity.inverse())
        self.assertEqual(identity, identity.between(identity))

    def test_lie_group_ops_{{ cls.__name__ }}(self):
        # type: () -> None
        """
        Tests:
            sym.{{ cls.__name__ }} LieGroupOps

        Checks from_tangent / to_tangent round-trip, retract, local_coordinates and
        the end points of interpolate, all on a random tangent-space perturbation.
        """

        geo_class = sym.{{ cls.__name__ }}
        logger.debug("*** Testing LieGroupOps: %s ***", geo_class.__name__)

        tangent_dim = geo_class.tangent_dim()
        self.assertGreater(tangent_dim, 0)
        self.assertLessEqual(tangent_dim, geo_class.storage_dim())

        perturbation = np.random.rand(tangent_dim)
        value = geo_class.from_tangent(perturbation)
        recovered_perturbation = geo_class.to_tangent(value)
        np.testing.assert_almost_equal(perturbation, recovered_perturbation)

        # Retracting by the negated perturbation should land back on the identity
        identity = geo_class.identity()
        recovered_identity = value.retract(-recovered_perturbation)
        np.testing.assert_almost_equal(recovered_identity.to_storage(), identity.to_storage())

        perturbation_zero = identity.local_coordinates(recovered_identity)
        np.testing.assert_almost_equal(perturbation_zero, np.zeros(tangent_dim))

        # Interpolation at 0 and 1 should return the respective end points
        np.testing.assert_almost_equal(
            identity.interpolate(value, 0.0).to_storage(), identity.to_storage()
        )
        np.testing.assert_almost_equal(
            identity.interpolate(value, 1.0).to_storage(), value.to_storage()
        )

    def test_custom_methods_{{ cls.__name__ }}(self):
        # type: () -> None
        """
        Tests:
            sym.{{ cls.__name__ }} custom methods

        Checks the type's action on vectors and how its constructor handles column
        vectors and invalid arguments.
        """
        geo_class = sym.{{ cls.__name__ }}
        logger.debug("*** Testing Custom Methods: %s ***", geo_class.__name__)

        tangent_dim = geo_class.tangent_dim()
        element = geo_class.from_tangent(np.random.normal(size=tangent_dim))

        # The column vector's length is the last character of the type name
        {% set dim = cls.__name__[-1] %}
        vector = np.random.normal(size=({{ dim }}, 1))
        {% if "Rot" in cls.__name__ %}
        matrix = element.to_rotation_matrix()
        np.testing.assert_almost_equal(np.matmul(matrix, vector), element * vector)

        # Test constructor handles column vectors correctly
        col_data = np.random.normal(size=(geo_class.storage_dim(), 1))
        rot = geo_class(col_data)
        expected_data = col_data.ravel().tolist()
        self.assertEqual(expected_data, rot.data)
        for x in rot.data:
            # NOTE(brad): One failure mode is for x to not be a float but an ndarray.
            # This isn't caught by the above because [np.array([1.0])] == [1.0]
            # evaluates to True.
            self.assertIsInstance(x, float)

        # Test constructor raises an IndexError if input is too large
        with self.assertRaises(IndexError):
            geo_class([1, 2, 3, 4, 5, 6])
        {% else %}
        vector_as_element = geo_class(t=vector.ravel().tolist())
        np.testing.assert_almost_equal(
            element * vector.ravel(), (element * vector_as_element).position()
        )

        # Test position/rotation accessors
        np.testing.assert_equal(element.position(), element.t)
        self.assertEqual(element.rotation(), element.R)

        # Test position/rotation accessors are compatible with Constructor
        element_copy = geo_class(t=element.t, R=element.R)
        self.assertEqual(element.data, element_copy.data)
        for x in element_copy.data:
            # NOTE(brad): One failure mode is for x to not be a float but an ndarray.
            # This isn't caught by the above because [np.array([1.0])] == [1.0]
            # evaluates to True.
            self.assertIsInstance(x, float)

        # Test constructor handles column vectors correctly.
        column_t = np.expand_dims(np.array(element.t).ravel(), axis=1)
        column_element = geo_class(t=column_t, R=element.R)
        self.assertEqual(element.data, column_element.data)
        for x in column_element.data:
            # Same failure mode as above: entries must be plain floats, not ndarrays
            self.assertIsInstance(x, float)

        # Test constructor raises an IndexError if input t is too large
        with self.assertRaises(IndexError):
            geo_class(t=[1, 2, 3, 4, 5, 6])

        # Test constructor raises a ValueError if a non Rot is passed for R
        with self.assertRaises(ValueError):
            geo_class(R=4)  # type: ignore[arg-type]
        {% endif %}

    {% endfor %}

# Run the test cases in this module when the file is executed directly as a script
if __name__ == "__main__":
    unittest.main()
